{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_lNeCgAVkdhM"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "uDcWxmG9kh1Q"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UhoPXG8g1CBt"
   },
   "source": [
    "# WordCount in TFF\n",
    "\n",
    "This notebook demonstrates a basic analytics query (finding the most frequently occurring words in Shakespeare), implemented first in pure TensorFlow and then as a TFF computation.\n",
    "\n",
    "The goal is to show how an analytics query can be expressed in TFF."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 68
    },
    "colab_type": "code",
    "id": "Ary-OZz5jMJI",
    "outputId": "b495b0da-43fa-41a1-f411-da51c9596850"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/sh: pip: command not found\n",
      "/usr/bin/sh: pip: command not found\n",
      "/usr/bin/sh: pip: command not found\n"
     ]
    }
   ],
   "source": [
    "#@test {\"skip\": true}\n",
    "\n",
    "# Note: If you are running a Jupyter notebook, and installing a locally built\n",
    "# pip package, you may need to edit the following to point to the '.whl' file\n",
    "# on your local filesystem.\n",
    "\n",
    "# `%pip` runs pip for the interpreter backing this kernel, so it does not\n",
    "# depend on a `pip` executable being available on the shell's PATH.\n",
    "%pip install --quiet --upgrade tensorflow_federated\n",
    "%pip install --quiet --upgrade tf-nightly\n",
    "%pip install --quiet --upgrade tensorflow-text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "f6JYwcvYBkGl"
   },
   "outputs": [],
   "source": [
    "# Import TensorFlow through its public package name.\n",
    "import tensorflow as tf\n",
    "import tensorflow_federated as tff\n",
    "import tensorflow_text as text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GPbOsTO5Bwu1"
   },
   "outputs": [],
   "source": [
    "# https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/shakespeare/load_data\n",
    "shakespeare_splits = tff.simulation.datasets.shakespeare.load_data()\n",
    "shake_train, shake_test = shakespeare_splits\n",
    "\n",
    "# Merge every client's examples into one dataset for the non-federated baseline.\n",
    "training_ds = shake_train.create_tf_dataset_from_all_clients()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "J3uvbmRCB0QA"
   },
   "source": [
    "# Dataset Pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 68
    },
    "colab_type": "code",
    "id": "ik3DNHAc1b7j",
    "outputId": "e2881add-2b51-4742-c352-a3bed364879d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b'frown'\n",
      "b'pages'\n",
      "b'marcus'\n"
     ]
    }
   ],
   "source": [
    "## Preprocess the dataset to split each line into individual words\n",
    "TRIM_TO_LINES = 1000\n",
    "\n",
    "tokenizer = text.UnicodeScriptTokenizer()\n",
    "\n",
    "\n",
    "def preprocess(ds):\n",
    "  \"\"\"Converts a dataset of text snippets into a shuffled dataset of words.\n",
    "\n",
    "  Only the first TRIM_TO_LINES elements of `ds` are used. The 'snippets' field\n",
    "  of each element is tokenized, every token is case-folded, and tokens whose\n",
    "  word shape is IS_PUNCT_OR_SYMBOL are dropped.\n",
    "  \"\"\"\n",
    "\n",
    "  def _snippet_to_words(example):\n",
    "    # The tokenizer is given a batch holding a single snippet; [0] selects the\n",
    "    # tokens of that one snippet.\n",
    "    snippet_batch = tf.expand_dims(example['snippets'], 0)\n",
    "    return tf.data.Dataset.from_tensor_slices(\n",
    "        tokenizer.tokenize(snippet_batch)[0])\n",
    "\n",
    "  def _is_word(token):\n",
    "    is_punct_or_symbol = text.wordshape(token, text.WordShape.IS_PUNCT_OR_SYMBOL)\n",
    "    return tf.math.logical_not(is_punct_or_symbol)\n",
    "\n",
    "  # Trim dataset to improve example performance\n",
    "  trimmed = ds.take(TRIM_TO_LINES)\n",
    "  words = trimmed.flat_map(_snippet_to_words)\n",
    "  words = words.map(text.case_fold_utf8)\n",
    "  words = words.filter(_is_word)\n",
    "  return words.shuffle(buffer_size=50000)\n",
    "\n",
    "\n",
    "dataset = preprocess(training_ds)\n",
    "# Show a few words as a sanity check.\n",
    "for i in dataset.take(3):\n",
    "  print(i.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 51
    },
    "colab_type": "code",
    "id": "iyv4sCQA9mgU",
    "outputId": "258ae3a4-f6ef-45e0-f640-a9931d1935b9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final vocab length 3889\n",
      "[b'road', b'and', b'call']\n"
     ]
    }
   ],
   "source": [
    "## Build a vocab dictionary by getting the set of unique words\n",
    "\n",
    "vocab = dataset.apply(tf.data.experimental.unique())\n",
    "\n",
    "# Materialize the unique words as a plain Python list of numpy values.\n",
    "vocab_list = []\n",
    "for unique_word in vocab:\n",
    "  vocab_list.append(unique_word.numpy())\n",
    "\n",
    "print('Final vocab length %d' % len(vocab_list))\n",
    "print(vocab_list[:3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UAGhFahDB7fm"
   },
   "source": [
    "# Pure TensorFlow implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 561
    },
    "colab_type": "code",
    "id": "uqk1HRZ2jUtO",
    "outputId": "675bda29-04a9-4cf8-b7fe-8485a95bfc75"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocab size: 3889\n",
      "...\n",
      "the: 736\n",
      "and: 677\n",
      "i: 592\n",
      "to: 575\n",
      "of: 422\n",
      "you: 364\n",
      "d: 338\n",
      "a: 319\n",
      "my: 310\n",
      "that: 303\n",
      "in: 293\n",
      "not: 267\n",
      "is: 251\n",
      "he: 250\n",
      "s: 238\n",
      "me: 232\n",
      "it: 231\n",
      "him: 219\n",
      "with: 211\n",
      "have: 199\n",
      "his: 194\n",
      "be: 191\n",
      "thou: 185\n",
      "for: 185\n",
      "we: 183\n",
      "this: 174\n",
      "as: 170\n",
      "your: 166\n",
      "but: 166\n",
      "so: 151\n"
     ]
    }
   ],
   "source": [
    "# For a quick smoke run, try BATCH_SIZE = 5, TAKE = 2, K = 5.\n",
    "BATCH_SIZE = 10000  # Number of words counted per step.\n",
    "TAKE = -1  # Number of batches to process; -1 processes all of them.\n",
    "K = 30  # How many of the most frequent words to report.\n",
    "\n",
    "vocab_lookup = tf.lookup.index_table_from_tensor(vocab_list)\n",
    "vocab_size = tf.cast(vocab_lookup.size(), tf.int32)\n",
    "print('Vocab size: %d' % vocab_size.numpy())\n",
    "\n",
    "# Accumulate a per-word histogram one batch at a time.\n",
    "counts = tf.zeros([vocab_size])\n",
    "for batch in dataset.batch(BATCH_SIZE).take(TAKE):\n",
    "  indices = vocab_lookup.lookup(batch)\n",
    "  onehot = tf.one_hot(indices, depth=vocab_size)\n",
    "  counts += tf.reduce_sum(onehot, axis=0)\n",
    "  print('.', end='')\n",
    "\n",
    "# Only the final histogram is reported, so rank it once after the loop. This\n",
    "# also keeps top_words/top_vals defined when the loop body never runs.\n",
    "top_vals, top_indices = tf.math.top_k(counts, k=K)\n",
    "top_words = tf.gather(vocab_list, top_indices)\n",
    "\n",
    "print()\n",
    "for word, count in zip(top_words, top_vals):\n",
    "  print('%s: %d' % (word.numpy().decode('utf-8'), count))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2mqUEgAQCqjt"
   },
   "source": [
    "# TensorFlow Federated Approach"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 85
    },
    "colab_type": "code",
    "id": "SGyFDlLx0K7e",
    "outputId": "15813a5f-a5f2-4a0e-b4ca-2d9061f638f9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num clients: 25\n",
      "[b'within' b'is' b'i']\n",
      "[b'a' b'reports' b'is']\n",
      "[b'octavia' b'how' b'widower']\n"
     ]
    }
   ],
   "source": [
    "## Dataset prep\n",
    "TRIM_TO_CLIENTS = 25\n",
    "\n",
    "client_ids = shake_train.client_ids[:TRIM_TO_CLIENTS]\n",
    "\n",
    "# Build one preprocessed word dataset per selected client.\n",
    "client_datasets = []\n",
    "for client_id in client_ids:\n",
    "  raw_client_ds = shake_train.create_tf_dataset_for_client(client_id)\n",
    "  client_datasets.append(preprocess(raw_client_ds))\n",
    "\n",
    "print('Num clients: %d' % len(client_datasets))\n",
    "# Peek at the first three words of the first three clients.\n",
    "for ds in client_datasets[:3]:\n",
    "  for words in ds.batch(3).take(1):\n",
    "    print(words.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 544
    },
    "colab_type": "code",
    "id": "PKdD2gW3Cz_9",
    "outputId": "46deb4bd-294d-4530-bd82-9b390ace4576"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".........................\n",
      "the: 614\n",
      "and: 549\n",
      "to: 480\n",
      "i: 435\n",
      "of: 357\n",
      "you: 316\n",
      "d: 278\n",
      "a: 264\n",
      "in: 243\n",
      "that: 237\n",
      "my: 232\n",
      "he: 211\n",
      "not: 203\n",
      "is: 187\n",
      "it: 186\n",
      "with: 179\n",
      "him: 175\n",
      "s: 171\n",
      "we: 165\n",
      "me: 163\n",
      "his: 163\n",
      "for: 161\n",
      "have: 160\n",
      "your: 149\n",
      "be: 145\n",
      "this: 142\n",
      "as: 140\n",
      "thou: 129\n",
      "but: 127\n",
      "so: 119\n"
     ]
    }
   ],
   "source": [
    "## Initial decomposition to map-reduce style, not actually TFF just yet!\n",
    "@tf.function\n",
    "def client_map_step(ds):\n",
    "  \"\"\"Counts how often each vocabulary word occurs in one client's dataset.\n",
    "\n",
    "  Args:\n",
    "    ds: A dataset of words, as produced by `preprocess`.\n",
    "\n",
    "  Returns:\n",
    "    An int32 tensor of length len(vocab_list) whose i-th entry is the number\n",
    "    of occurrences of vocab_list[i] within the batches that were processed.\n",
    "  \"\"\"\n",
    "  # N.B. vocab_size and vocab_lookup must be created inside the @tf.function\n",
    "  vocab_size = len(vocab_list)\n",
    "  vocab_lookup = tf.lookup.index_table_from_tensor(vocab_list)\n",
    "\n",
    "  @tf.function\n",
    "  def _count_words_in_batch(accumulator, batch):\n",
    "    # One-hot encode each word's vocab index and sum over the batch to get\n",
    "    # this batch's histogram, then add it to the running total.\n",
    "    indices = vocab_lookup.lookup(batch)\n",
    "    onehot = tf.one_hot(indices, depth=tf.cast(vocab_size, tf.int32), dtype=tf.int32)\n",
    "    return accumulator + tf.reduce_sum(onehot, axis=0)\n",
    "\n",
    "  return ds.batch(BATCH_SIZE).take(TAKE).reduce(\n",
    "      initial_state=tf.zeros([vocab_size], tf.int32),\n",
    "      reduce_func=_count_words_in_batch)\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def cross_client_reduce_step(client_aggregates):\n",
    "  \"\"\"Sums the per-client count vectors and picks the K most frequent words.\n",
    "\n",
    "  Args:\n",
    "    client_aggregates: A list of count tensors, one per client, as returned by\n",
    "      `client_map_step`.\n",
    "\n",
    "  Returns:\n",
    "    A (top_words, top_vals) pair: the K words with the highest summed counts\n",
    "    and those counts.\n",
    "  \"\"\"\n",
    "  reduced = tf.math.add_n(client_aggregates)\n",
    "  top_vals, top_indices = tf.math.top_k(reduced, k=K)\n",
    "  top_words = tf.gather(vocab_list, top_indices)\n",
    "  return top_words, top_vals\n",
    "\n",
    "\n",
    "# Wire it all together\n",
    "client_sums = list()\n",
    "for client_ds in client_datasets:\n",
    "  print('.', end='')\n",
    "  client_sums.append(client_map_step(client_ds))\n",
    "top_words, top_counts = cross_client_reduce_step(client_sums)\n",
    "\n",
    "\n",
    "print()\n",
    "# Pair each word with the counts computed in this cell (top_counts).\n",
    "for word, count in zip(top_words, top_counts):\n",
    "  print('%s: %d' % (word.numpy().decode('utf-8'), count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 136
    },
    "colab_type": "code",
    "id": "Qg53eovx3kZQ",
    "outputId": "284fbb2d-6131-4966-eb07-c851f9c59f9e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(string* -> int32[3889])\n",
      "{int32[3889]}@CLIENTS\n",
      "( -> int32[3889])\n",
      "(<int32[3889],int32[3889]> -> int32[3889])\n",
      "(int32[3889] -> <string[30],int32[30]>)\n",
      "<string[30],int32[30]>@SERVER\n",
      "({string*}@CLIENTS -> <string[30],int32[30]>@SERVER)\n"
     ]
    }
   ],
   "source": [
    "@tff.federated_computation(\n",
    "    tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS))\n",
    "def federated_top_k_words(client_datasets):\n",
    "  \"\"\"Computes the K most frequent words across clients as a TFF computation.\n",
    "\n",
    "  The print calls run while the decorator traces this function and show the\n",
    "  TFF type signature of each stage.\n",
    "\n",
    "  Args:\n",
    "    client_datasets: Client-placed sequences of strings.\n",
    "\n",
    "  Returns:\n",
    "    A server-placed pair of the top K words and their counts.\n",
    "  \"\"\"\n",
    "  # Map: wrap the TF counting step so it applies to one client's sequence.\n",
    "  count_words = tff.tf_computation(\n",
    "      client_map_step, client_datasets.type_signature.member)\n",
    "  print(count_words.type_signature)  # (string* -> int32[VOCAB_SIZE])\n",
    "\n",
    "  client_aggregates = tff.federated_map(count_words, client_datasets)\n",
    "  print(client_aggregates.type_signature)  # {int32[VOCAB_SIZE]}@CLIENTS\n",
    "\n",
    "  @tff.tf_computation()\n",
    "  def zero_counts():\n",
    "    return tf.zeros([len(vocab_list)], tf.int32)\n",
    "  print(zero_counts.type_signature)  # ( -> int32[VOCAB_SIZE])\n",
    "\n",
    "  counts_type = count_words.type_signature.result\n",
    "\n",
    "  @tff.tf_computation(counts_type, counts_type)\n",
    "  def add_counts(accum, delta):\n",
    "    return accum + delta\n",
    "  print(add_counts.type_signature)  # (<int32[VOCAB_SIZE],int32[VOCAB_SIZE]> -> int32[VOCAB_SIZE])\n",
    "\n",
    "  @tff.tf_computation(add_counts.type_signature.result)\n",
    "  def top_k_report(accum):\n",
    "    top_vals, top_indices = tf.math.top_k(accum, k=K)\n",
    "    top_words = tf.gather(vocab_list, top_indices)\n",
    "    return top_words, top_vals\n",
    "  print(top_k_report.type_signature)  # (int32[VOCAB_SIZE] -> <string[K],int32[K]>)\n",
    "\n",
    "  # Reduce: the same addition is used both to accumulate client values and to\n",
    "  # merge accumulators; the report step extracts the top K.\n",
    "  aggregate = tff.federated_aggregate(\n",
    "      value=client_aggregates,\n",
    "      zero=zero_counts(),\n",
    "      accumulate=add_counts,\n",
    "      merge=add_counts,\n",
    "      report=top_k_report,\n",
    "  )\n",
    "  print(aggregate.type_signature)  # <string[K],int32[K]>@SERVER\n",
    "\n",
    "  return aggregate\n",
    "\n",
    "print(federated_top_k_words.type_signature)  # ({string*}@CLIENTS -> <string[K],int32[K]>@SERVER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {
     "height": 527
    },
    "colab_type": "code",
    "id": "C2C2xhR86MkI",
    "outputId": "1bf81cc1-2a76-49ef-e11a-37c158078f23"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the: 614\n",
      "and: 549\n",
      "to: 480\n",
      "i: 435\n",
      "of: 357\n",
      "you: 316\n",
      "d: 278\n",
      "a: 264\n",
      "in: 243\n",
      "that: 237\n",
      "my: 232\n",
      "he: 211\n",
      "not: 203\n",
      "is: 187\n",
      "it: 186\n",
      "with: 179\n",
      "him: 175\n",
      "s: 171\n",
      "we: 165\n",
      "me: 163\n",
      "his: 163\n",
      "for: 161\n",
      "have: 160\n",
      "your: 149\n",
      "be: 145\n",
      "this: 142\n",
      "as: 140\n",
      "thou: 129\n",
      "but: 127\n",
      "so: 119\n"
     ]
    }
   ],
   "source": [
    "# Run the federated computation on a copy of the list of client datasets.\n",
    "top_words, top_vals = federated_top_k_words(list(client_datasets))\n",
    "for word, count in zip(top_words, top_vals):\n",
    "  print('%s: %d' % (word.decode('utf-8'), count))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "WordCount in TFF",
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
